{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solving an overdetermined system using  pseudo inverse\n",
    "\n",
    "Consider the overdetermined system corresponding to cat-brain from Chapter 2.\n",
    "\n",
    "There are 15 training examples, each with input and desired outputs specified.\n",
    "\n",
    "Our goal is to determine 3 unkwnowns (w0, w1, b).\n",
    "\n",
    "This can be cast as an over-determined system of equations\n",
    "$$\n",
    "A\\vec{w} = \\vec{y}\n",
    "$$\n",
    "where\n",
    "$$ \n",
    "A =\n",
    "\\begin{bmatrix}\n",
    "        0.11 & 0.09 & 1.00 \\\\\n",
    "        0.01 & 0.02 & 1.00 \\\\\n",
    "        0.98 & 0.91 & 1.00 \\\\\n",
    "        0.12 & 0.21 & 1.00 \\\\\n",
    "        0.98 & 0.99 & 1.00 \\\\\n",
    "        0.85 & 0.87 & 1.00 \\\\\n",
    "        0.03 & 0.14 & 1.00 \\\\\n",
    "        0.55 & 0.45 & 1.00 \\\\\n",
    "        0.49 & 0.51 & 1.00 \\\\\n",
    "        0.99 & 0.01 & 1.00 \\\\\n",
    "        0.02 & 0.89 & 1.00 \\\\\n",
    "        0.31 & 0.47 & 1.00 \\\\\n",
    "        0.55 & 0.29 & 1.00 \\\\\n",
    "        0.87 & 0.76 & 1.00 \\\\\n",
    "        0.63 & 0.24 & 1.00\n",
    "\\end{bmatrix}\n",
    "\\;\\;\\;\\;\\;\\;\\;\n",
    "\\vec{y} = \n",
    "\\begin{bmatrix}\n",
    "        -0.8 \\\\\n",
    "        -0.97 \\\\\n",
    "        0.89 \\\\ \n",
    "        -0.67 \\\\ \n",
    "        0.97 \\\\ \n",
    "        0.72 \\\\ \n",
    "        -0.83 \\\\ \n",
    "        0.00 \\\\\n",
    "        0.00 \\\\\n",
    "        0.00 \\\\\n",
    "        -0.09 \\\\\n",
    "        -0.22 \\\\ \n",
    "        -0.16 \\\\\n",
    "        0.63 \\\\\n",
    "        0.37\n",
    "\\end{bmatrix}\n",
    "\\;\\;\\;\\;\\;\\;\\;\n",
    "\\vec{w} = \\begin{bmatrix} w_{0}\\\\w_{1}\\\\b\\end{bmatrix}\n",
    "$$\n",
    "\n",
    "We solve for $\\vec{w}$ using the pseudo inverse formula $\\space\\space\\large{\\vec{w} = (A^TA)^{-1}A^Ty}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The solution is tensor([ 1.0766,  0.8976, -0.9582])\n",
      "Note that this is almost equal to [1.0, 1.0, -1.0])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "# Let us revisit our cat brain data set\n",
    "# Notice that there are 15 training examples, with 3\n",
    "# unkwnowns (w0, w1, b).\n",
    "# This is an over determined system.\n",
    "# It can be easily seen that the solution is roughly\n",
    "# $w_{0} = 1, w_{1} = 1, b = -1$.\n",
    "# It has been deliberately chosen as such.\n",
    "# But the equations are not fully consistent (i.e., there is\n",
    "# no solution that satisfies all the equations).\n",
    "# We want to find the best values such that it minimizes Aw - b.\n",
    "# This is what the pseudo-inverse does.\n",
    "\n",
    "def pseudo_inverse(A):\n",
    "    return torch.matmul(torch.linalg.inv(torch.matmul(A.T, A)), A.T)\n",
    "\n",
    "# The usual cat-brain input dataset\n",
    "X = torch.tensor([[0.11, 0.09], [0.01, 0.02], [0.98, 0.91], [0.12, 0.21],\n",
    "              [0.98, 0.99], [0.85, 0.87], [0.03, 0.14], [0.55, 0.45],\n",
    "              [0.49, 0.51], [0.99, 0.01], [0.02, 0.89], [0.31, 0.47],\n",
    "              [0.55, 0.29], [0.87, 0.76], [0.63, 0.24]])\n",
    "\n",
    "# Output threat score modeled as a vector\n",
    "y = torch.tensor([-0.8, -0.97, 0.89, -0.67, 0.97, 0.72, -0.83, 0.00, 0.00,\n",
    "              0.00, -0.09, -0.22, -0.16, 0.63, 0.37])\n",
    "A = torch.column_stack((X, torch.ones(15)))\n",
    "\n",
    "# Column stack will add an additional column of 1s to the training\n",
    "# dataset to represent the coefficient of the bias\n",
    "w = torch.matmul(pseudo_inverse(A), y)\n",
    "\n",
    "print(\"The solution is {}\\n\"\n",
    "      \"Note that this is almost equal to [1.0, 1.0, -1.0])\".format(w))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
